#ifndef SIMPLE_COMPRESS_HUFFMAN_HPP
#define SIMPLE_COMPRESS_HUFFMAN_HPP

#include <iterator> // std::iterator_traits
#include <cstddef> // std::size_t
#include <array> // std::array
#include <vector> // std::vector
#include <utility> // std::pair std::move
#include <limits> // std::numeric_limits
#include <type_traits> // std::enable_if_t std::underlying_type_t std::make_unsigned_t std::is_unsigned_v std::conditional_t std::is_integral_v
#include <cassert> // assert
#include <algorithm> // std::transform std::swap

#include "simple/support/type_traits/is_template_instance.hpp" // support::is_template_instance_v

#include "hash_table.hpp" // hash_table
#include "bits.hpp" // bit_count bit_offset get_bits bits read_bits

namespace simple::compress
{

	// oof, so painful we still don't have this in std, especially with all the junk that went through over the years
	//
	// Fixed-capacity sequence stored inline in a std::array.
	// All Capacity slots are value-initialized on construction, so T has to be
	// default constructible and assignable; only the first `count` slots are
	// considered part of the sequence. No bounds are enforced beyond asserts.
	template <typename T, std::size_t Capacity>
	class static_vector
	{
		std::array<T, Capacity> array;
		// Number of live elements. A count is stored instead of an iterator into
		// `array`, so the implicitly generated copy/move operations can never
		// leave an object referring to the storage of the one it was copied from.
		std::size_t count;

		public:

		// Creates an empty sequence.
		constexpr static_vector() : array(), count(0) {}

		// Iterators over the live elements only: [begin(), end()).
		constexpr auto begin() { return array.begin(); }
		constexpr auto begin() const { return array.begin(); }
		constexpr auto end() { return array.begin() + static_cast<std::ptrdiff_t>(count); }
		constexpr auto end() const { return array.begin() + static_cast<std::ptrdiff_t>(count); }

		// True when there are no live elements.
		constexpr bool empty() const
		{
			return count == 0;
		}

		// True when every slot is in use and push_back must not be called.
		constexpr bool full() const
		{
			return count == Capacity;
		}

		// Appends an element. Precondition: not full().
		constexpr void push_back(T element)
		{
			assert(!full());
			array[count++] = std::move(element);
		}

		// Drops the last element. Precondition: not empty().
		// The slot keeps its old value until it is overwritten by a later push_back.
		constexpr void pop_back()
		{
			assert(!empty());
			--count;
		}

		// Last live element (not the slot past it). Precondition: not empty().
		constexpr auto& back()
		{
			assert(!empty());
			return array[count - 1];
		}
		constexpr auto& back() const
		{
			assert(!empty());
			return array[count - 1];
		}

	};

	// Table mapping keys to values; the implementation is chosen by key width.
	// Only declared here: the `Enabled` parameter is the slot the partial
	// specializations use (via std::enable_if_t) to select themselves, so a key
	// type matched by no specialization leaves small_table incomplete.
	template <typename SmallKey, typename Value, typename Enabled = void>
	class small_table;

	// Array-backed table for unsigned keys of at most 13 bits.
	// Holds one Value slot for every possible key (1 << bit_count slots), all
	// value-initialized, so lookup is plain indexing with no search.
	template <typename SmallKey, typename Value>
	class small_table<SmallKey, Value, std::enable_if_t<bit_count(SmallKey{}) <= 13 && std::is_unsigned_v<SmallKey>>>
	{
		std::array<Value, 1 << bit_count(SmallKey{})> table{};

		// Slot index of a key: its raw bits shifted right so that the key's bit
		// range (offset + count) ends at the least significant bit.
		static constexpr auto slot_index(const SmallKey& key)
		{
			auto raw = get_bits(key);
			constexpr auto width = bit_count(SmallKey{});
			constexpr auto start = bit_offset(SmallKey{});
			constexpr auto digits = std::numeric_limits<decltype(raw)>::digits;
			raw >>= digits - (start + width);
			return raw;
		}

		// Inverse of slot_index: rebuilds the key that is stored at a given slot.
		static constexpr SmallKey key_for_slot(std::size_t slot)
		{
			using raw_key = decltype(get_bits(SmallKey{}));
			constexpr auto width = bit_count(SmallKey{});
			constexpr auto start = bit_offset(SmallKey{});
			constexpr auto digits = std::numeric_limits<raw_key>::digits;
			auto raw = static_cast<raw_key>(slot);
			raw <<= digits - width - start;
			return static_cast<SmallKey>(raw);
		}

		// Shared body of both operator[] overloads; Self carries the constness.
		template <typename Self>
		static constexpr auto& get(Self& self, const SmallKey& key)
		{
			const auto slot = slot_index(key);
			assert(slot < self.table.size());
			return self.table[slot];
		}

		public:

		using key_type = SmallKey;
		using value_type = Value;

		// Access to the value stored for `key`; every key always has a slot.
		constexpr Value& operator[](const SmallKey& key)
		{ return get(*this, key); }

		constexpr const Value& operator[](const SmallKey& key) const
		{ return get(*this, key); }

		// Calls f with a (key, value) pair copy for every slot, in slot order,
		// including slots that still hold a value-initialized Value.
		template <typename F>
		constexpr void for_each(F&& f) const
		{
			std::size_t slot = 0;
			for(const auto& value : table)
				f(std::pair{key_for_slot(slot++), value});
		}

		// Calls f with a (key, value) pair copy for each slot in order and stops
		// after the first call that returns true. Nothing is returned, so the
		// callback has to record the match itself.
		template <typename F>
		// TODO: oof, need proper iterators, also this is used to read codes, and that's slow, prolly better walk the tree bit by bit
		constexpr void find_if(F&& f) const
		{
			std::size_t slot = 0;
			for(const auto& value : table)
			{
				if(f(std::pair{key_for_slot(slot), value}))
					return;
				++slot;
			}
		}

	};

	// Specialization selected for keys wider than 13 bits, backed by a hash_table
	// instead of a dense array.
	// NOTE: unfinished - it only declares its storage and offers none of the
	// operator[] / for_each / find_if interface of the small-key specialization.
	template <typename BigKey, typename Value>
	class small_table<BigKey, Value, std::enable_if_t<(bit_count(BigKey{}) > 13)>>
	{
		// NOTE(review): the third argument (1<<16) is presumably the table capacity - confirm against hash_table.hpp
		hash_table<BigKey, std::pair<BigKey, Value>, (1<<16)> table;
		// TODO
	};

	// Maps a symbol type to the integer type that represents it:
	//  - integral types map to themselves,
	//  - enumeration types map to their std::underlying_type.
	// Any other type matches no specialization, so underlying_type<T> stays
	// incomplete and the failure can be detected with SFINAE, instead of passing
	// a non-enum to std::underlying_type (not allowed before C++20).
	template <typename T, typename Enabled = void>
	struct underlying_type;
	template <typename T>
	struct underlying_type<T, std::enable_if_t<std::is_integral_v<T>>>
	{ using type = T; };
	template <typename T>
	struct underlying_type<T, std::enable_if_t<std::is_enum_v<T>>>
	{ using type = std::underlying_type_t<T>; };
	// Convenience alias for underlying_type<T>::type.
	template <typename T>
	using underlying_type_t = typename underlying_type<T>::type;

	template <typename It>
	[[nodiscard]] constexpr auto huffman_code(It begin, It end)
	{
		using key_type = std::make_unsigned_t<underlying_type_t<typename std::iterator_traits<It>::value_type>>;
		small_table<key_type, std::size_t> counter{};
		small_table<key_type, bits<>> code{}; // TODO smaller key -> less bits
		std::conditional_t<bit_count(key_type{}) <= 13,
			static_vector<std::pair<key_type,key_type>, (1 << bit_count(key_type{}))>,
			std::vector<std::pair<key_type,key_type>>
		> hierarchy{};
		for(auto i = begin; i != end; ++i)
			++counter[static_cast<key_type>(*i)];

		std::array<std::pair<key_type, std::size_t>, 2> minmin;
		while(true)
		{
			minmin = decltype(minmin){};
			// NOTE: this could be partial_sort_copy(no_zeros(counter), minmin, second_less), but can't be bothered to write the necessary iterators atm
			counter.for_each([&minmin](auto kv)
			{
				// filter
				if(kv.second != 0)
				{
					// find a smaller value
					if(minmin[1].second == 0 || kv.second < minmin[1].second)
					{
						minmin[1] = kv;
						// keep em sorted
						if(minmin[0].second == 0 || minmin[1].second < minmin[0].second)
							std::swap(minmin[0], minmin[1]);
					}
				}
			});

			if(0 == minmin[1].second)
				break;

			// FIXME handle code length overflow
			code[minmin[0].first].insert(0);
			code[minmin[1].first].insert(1);
			for(auto& [symbol, parent] : hierarchy)
			{
				if(parent == minmin[0].first)
				{
					code[symbol].insert(0);
					parent = minmin[0].first;
				}
				if(parent == minmin[1].first)
				{
					code[symbol].insert(1);
					parent = minmin[0].first;
				}

			}

			counter[minmin[0].first] += counter[minmin[1].first];
			counter[minmin[1].first] = 0;

			hierarchy.push_back({minmin[1].first, minmin[0].first});
		}

		// special case: there is only one symbol
		if(0 != minmin[0].second && hierarchy.empty())
			code[minmin[0].first].insert(0);

		return code;
	}

	// Replaces every symbol in [begin, end) by its entry in `code` and writes
	// the results through `out`.
	// Each symbol is cast to Code::key_type before the lookup.
	// Returns the output iterator past the last element written.
	template <typename It, typename Out, typename Code>
	constexpr auto huffman_encode(const Code& code, It begin, It end, Out out)
	{
		using key_type = typename Code::key_type;
		const auto lookup = [&code](auto&& symbol)
		{
			return code[static_cast<key_type>(symbol)];
		};
		return std::transform(begin, end, std::move(out), lookup);
	}

	// Decodes symbols from the bit stream at `i` into [out, out_end).
	// Overload selected when I is an instance of bit_iterator.
	//
	// For each output slot the code table is scanned in table order; the first
	// non-empty code equal to the bits read at the current position wins, its
	// key is written to the slot and the input advances past the bits read.
	// If no code matches, the slot is left unwritten and the input position
	// does not move. Returns the input position after the last decoded symbol.
	// NOTE(review): read_bits(i, x) is assumed to fill x with as many bits as it
	// holds and return the iterator past them - confirm against bits.hpp.
	template <typename Code, typename I, typename O,
		std::enable_if_t<support::is_template_instance_v<bit_iterator, I>>* = nullptr
	>
	constexpr auto huffman_decode(const Code& code, I i, O out, O out_end)
	{
		// FIXME: this is super slow
		// can sort codes, then each read bit will determine a partition, until we're left with one element
		using out_value = typename std::iterator_traits<O>::value_type;

		// tries a single table entry against the bits at the current position
		const auto try_entry = [&i, &out](auto&& entry)
		{
			// symbols without a code can never match
			if(bit_count(entry.second) == 0)
				return false;

			// start from a copy of the code so the read covers the same bits
			auto candidate = entry.second;
			auto after = read_bits(i, candidate);
			if(candidate == entry.second)
			{
				*out = static_cast<out_value>(entry.first);
				i = after;
				return true;
			}
			return false;
		};

		for(; out != out_end; ++out)
			code.find_if(try_entry);

		return i;
	}

	// Convenience overload for input iterators that are not a bit_iterator:
	// wraps `i` as bit_iterator{i,0} and forwards to the bit_iterator overload.
	// NOTE(review): the 0 is presumably the starting bit offset within the
	// first element - confirm against bits.hpp.
	// Returns whatever the bit_iterator overload returns, i.e. a bit_iterator,
	// not an iterator of type I.
	template <typename Code, typename I, typename O,
		std::enable_if_t<not support::is_template_instance_v<bit_iterator, I>>* = nullptr
	>
	constexpr auto huffman_decode(const Code& code, I i, O out_begin, O out_end)
	{
		return huffman_decode(code, bit_iterator{i,0}, out_begin, out_end);
	}

} // namespace simple::compress

#endif /* end of include guard */
